package com.itrip.log.module.db.mybatis;

import io.protostuff.LinkedBuffer;
import io.protostuff.ProtostuffIOUtil;
import io.protostuff.Schema;
import io.protostuff.runtime.RuntimeSchema;

import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.statement.SQLDeleteStatement;
import com.alibaba.druid.sql.ast.statement.SQLInsertStatement;
import com.alibaba.druid.sql.ast.statement.SQLUpdateStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.alibaba.druid.sql.visitor.SchemaStatVisitor;
import com.alibaba.druid.stat.TableStat.Name;

@Intercepts({
		@Signature(type = Executor.class, method = "update", args = { MappedStatement.class, Object.class }),
		@Signature(type = Executor.class, method = "query", args = { MappedStatement.class, Object.class,
				RowBounds.class, ResultHandler.class }) })
public class CacheInterceptor implements Interceptor {

	/***
	 * 缓存的对象实体
	 */
	public static class CacheItem {
		/** 要缓存的数据 */
		Object data;
	}

	public static interface DeepCopy {
		Object copy();
	}

	private Properties properties;
	private static final Logger LOG = LoggerFactory.getLogger(CacheInterceptor.class);
	private static final Schema<CacheItem> SCHEMA = RuntimeSchema.getSchema(CacheItem.class);

	private ConcurrentHashMap<Object, Object> CAHCE = new ConcurrentHashMap<Object, Object>();
	/** 记录table关联哪些缓存,在这些表执行更新删除插入时清除缓存 key:table value:CAHCE.key */
	private ConcurrentHashMap<String, Set<String>> TABLE_SQL = new ConcurrentHashMap<String, Set<String>>();
	/*** SQLid对应的table,避免每次解析 **/
	private ConcurrentHashMap<String, Set<String>> SQL_ID_Table = new ConcurrentHashMap<String, Set<String>>();

	public Object intercept(Invocation invocation) throws Throwable {
		Object[] args = invocation.getArgs();
		MappedStatement mst = (MappedStatement) args[0];
		if (mst.isUseCache())
			return invocation.proceed();

		BoundSql boundSql = mst.getBoundSql(args[1]);
		if (mst.getSqlCommandType() == SqlCommandType.SELECT)
			return doSelect(invocation, mst, boundSql);
		Object obj = invocation.proceed();
		removeTableRelatedCache(mst, boundSql);
		return obj;
	}

	private void removeTableRelatedCache(MappedStatement mst, BoundSql boundSql) {
		Set<String> tables = SQL_ID_Table.get(mst.getId());
		if (tables == null) {
			tables = parseTables(boundSql.getSql());
			SQL_ID_Table.put(mst.getId(), tables);
		}
		for (String table : tables) {
			Set<String> cacheKeys = TABLE_SQL.remove(table);
			if (cacheKeys == null)
				continue;
			for (String item : cacheKeys) {
				CAHCE.remove(item);
			}
			LOG.info("remove cache when table:{} update", table);
		}
	}

	private Object doSelect(Invocation invocation, MappedStatement mst, BoundSql boundSql) throws Exception {
		String key = createCacheKey(mst, boundSql);
		CacheItem item = new CacheItem();
		Object data = CAHCE.get(key);
		if (data == null) {
			Object obj = invocation.proceed();
			if (obj == null)
				return obj;
			if (!isNotSeriaClass(obj)) {
				CAHCE.put(key, obj);
				return obj;
			}
			for (String table : parseSelectTables(mst.getId(), boundSql)) {
				Set<String> set = TABLE_SQL.get(table);
				if (set == null) {
					TABLE_SQL.putIfAbsent(table, new HashSet<String>());
				}
				TABLE_SQL.get(table).add(key);
			}
			item.data = obj;
			LinkedBuffer allocate = LinkedBuffer.allocate(256);
			CAHCE.put(key, ProtostuffIOUtil.toByteArray(item, SCHEMA, allocate));
			return obj;
		} else if (data instanceof byte[]) {
			ProtostuffIOUtil.mergeFrom((byte[]) data, item, SCHEMA);
		} else if (data instanceof DeepCopy) {
			item.data = ((DeepCopy) data).copy();
		} else {
			item.data = data;
		}
		LOG.info("hit cache:{}", boundSql);
		return item.data;
	}

	private String createCacheKey(MappedStatement mst, BoundSql bound) {
		StringBuilder key = new StringBuilder(mst.getId());
		Object pObject = bound.getParameterObject();
		Configuration cfg = mst.getConfiguration();
		List<ParameterMapping> mappings = bound.getParameterMappings();
		if (mappings.size() > 0 && pObject != null) {
			TypeHandlerRegistry registry = cfg.getTypeHandlerRegistry();
			if (registry.hasTypeHandler(pObject.getClass())) {
				return key.append(pObject).toString();
			}

			MetaObject meta = cfg.newMetaObject(pObject);
			for (ParameterMapping param : mappings) {
				String pname = param.getProperty();
				if (meta.hasGetter(pname)) {
					Object value = meta.getValue(pname);
					key.append(value);
				} else if (bound.hasAdditionalParameter(pname)) {
					Object value = bound.getAdditionalParameter(pname);
					key.append(value);
				}
			}
		}
		return key.toString();
	}

	/***
	 * 判断是否需要序列化,不序列化的对象直接put
	 * 
	 * @param data
	 * @return
	 */
	private boolean isNotSeriaClass(Object data) {
		Class<?> cls = data.getClass();
		return cls.isPrimitive() || cls == DeepCopy.class || (cls == Map.class && ((Map<?, ?>) data).isEmpty())
				|| (cls == Collection.class && ((Collection<?>) data).isEmpty())
				|| (cls == Object[].class && ((Object[]) data).length == 0);
	}

	/**
	 * 获取sql关联的table的name
	 * 
	 * @param regx
	 * @param srcSql
	 * @return
	 */
	private Set<String> parseTables(String srcSql) {
		Set<String> tables = new HashSet<String>();
		SQLStatementParser parser = new MySqlStatementParser(srcSql);
		for (SQLStatement stmt : parser.parseStatementList()) {
			if (stmt instanceof SQLInsertStatement) {
				SQLInsertStatement tmp = (SQLInsertStatement) stmt;
				tables.add(tmp.getTableName().getSimpleName());
			} else if (stmt instanceof SQLUpdateStatement) {
				SQLUpdateStatement tmp = (SQLUpdateStatement) stmt;
				tables.add(tmp.getTableName().getSimpleName());
			} else if (stmt instanceof SQLDeleteStatement) {
				SQLDeleteStatement tmp = (SQLDeleteStatement) stmt;
				tables.add(tmp.getTableName().getSimpleName());
			}
		}
		return tables;
	}

	private Set<String> parseSelectTables(String id, BoundSql boundSql) {
		Set<String> tables = SQL_ID_Table.get(id);
		if (tables != null)
			return tables;
		tables = new HashSet<String>();
		SchemaStatVisitor c = new SchemaStatVisitor();
		SQLStatementParser parser = new MySqlStatementParser(boundSql.getSql());
		for (SQLStatement stmt : parser.parseStatementList()) {
			stmt.accept(c);
			for (Name name : c.getTables().keySet()) {
				tables.add(name.getName().toLowerCase());
			}
		}
		SQL_ID_Table.put(id, tables);
		return tables;
	}

	public Object plugin(Object target) {
		return Plugin.wrap(target, this);
	}

	public void setProperties(Properties prop) {
		properties = prop;
	}

	public Properties getProperties() {
		return properties;
	}
}
